from PIL import Image
import json
import time
import torch
import asyncio
import aiohttp
from copy import deepcopy
import torch.nn as nn
import torch.nn.functional as F
from vllm.distributed.parallel_state import destroy_model_parallel
import time
import concurrent.futures
import requests
import json
import time
import concurrent.futures
import torch
import deepspeed
import gc
import ray
import base64
import time
import re
from transformers import Gemma3ForConditionalGeneration, AutoTokenizer, AutoProcessor, BitsAndBytesConfig
from tti.misc import merge_dicts
from tti.models.claude_prompts import SYSTEM_PROMPT
import os
from vllm import LLM, SamplingParams
import threading
import logging        # For logging messages and errors


import json
import time
import torch
import asyncio
import aiohttp
from copy import deepcopy


class GemmaVllmAgent(nn.Module):
    def __init__(self, policy_lm, config=None, 
                 use_lora=False, use_anyres=False, vllm_tensor_parallel_size=1):
        """
        Create an agent that wraps a Gemma 3 model (vLLM for inference, HF for training).

        Parameters:
            policy_lm (str): Hugging Face model identifier or local path. Used to load
                the tokenizer/processor and as the default weights for both modes.
            config: Optional config object. When given, ``temperature`` and
                ``max_new_tokens`` are read from it (defaults 1.0 and 128 otherwise),
                and ``evaluator_ip_address`` is read if present (default None).
            use_lora, use_anyres: Stored as attributes for signature compatibility.
            vllm_tensor_parallel_size (int): Tensor-parallel degree passed to vLLM.
        """
        super().__init__()
        # Filled in by enter_infer_mode (vLLM engine) / enter_train_mode (HF model).
        self.model = None
        self.train_model = None

        self.policy_lm = policy_lm
        self.vllm_tensor_parallel_size = vllm_tensor_parallel_size
        self.config = config
        # getattr tolerates config=None (the default) and configs without an evaluator;
        # previously this line raised AttributeError for config=None. get_reward
        # reports a missing endpoint itself when this is None.
        self.evaluator_ip_address = getattr(config, "evaluator_ip_address", None)
        self.temperature = config.temperature if config is not None else 1.0
        self.max_new_tokens = config.max_new_tokens if config is not None else 128
        self.use_anyres = use_anyres
        self.use_lora = use_lora
        self.mode = None  # None until the first enter_*_mode call, then "infer" or "train"
        # Used as vLLM's max_model_len and as the basis for prompt truncation.
        self.max_prompt_length = 16000
        self.updated_model_path = None

        # top_k = 64 is intentionally not set.
        # NOTE(review): stop ids 1 and 106 are presumably Gemma's <eos> and
        # <end_of_turn> tokens -- verify against the tokenizer. logprobs=True is
        # treated as the integer 1 (bool is an int subclass).
        self.sampling_params = SamplingParams(temperature=self.temperature, top_p = 0.95, min_p = 0.0,
                    max_tokens=self.max_new_tokens, logprobs=True,
                    stop_token_ids=[1, 106])

        self.tokenizer = AutoTokenizer.from_pretrained(policy_lm)
        self.processor = AutoProcessor.from_pretrained(policy_lm)
            
    def enter_infer_mode(self, updated_model_path=None):
        """
        Transition from training to inference mode by properly cleaning up resources
        and initializing vLLM.

        On deepspeed rank 0 this kills running ``ray``/``vllm`` processes, sets the
        VLLM_* environment variables, and builds a ``vllm.LLM`` from
        ``updated_model_path`` (or ``self.policy_lm`` when that is None), recording
        the path in ``self.updated_model_path``. Every other rank only sets
        ``self.model = None``.

        Parameters:
            updated_model_path (str | None): checkpoint to serve; defaults to
                ``self.policy_lm``. Ignored when the agent is already in infer mode
                (the method returns immediately).

        Errors while building the engine are printed, not raised: ``self.model``
        is then None while ``self.mode`` is already "infer".
        """
        # Skip if already in infer mode
        if self.mode == "infer":
            return
            
        previous_mode = self.mode        
        
        # NOTE: no train-side cleanup happens here; self.train_model is left as-is.
        # (An earlier DeepSpeed model cleanup step for this transition was disabled.)
        
        # Initialize vLLM only on rank 0
        # NOTE(review): because the mode flips before the engine is built, a failed
        # build still leaves mode == "infer", so a later enter_infer_mode() call
        # returns early without retrying until enter_train_mode() is called.
        self.mode = "infer"  # Update mode before vLLM init to prevent race conditions

        if deepspeed.comm.get_rank() == 0:
            print(f"Transitioning to inference mode from {previous_mode}")
            # Kill any existing Ray processes first
            # NOTE(review): `pkill -f` matches the full command line of every process,
            # so these also hit unrelated processes whose arguments merely contain
            # "ray" or "vllm" -- possibly this one. TODO confirm the launch command.
            # os.system reports failures through its (ignored) return value rather
            # than raising, so the except below rarely fires.
            try:
                os.system("pkill -9 -f ray")
                os.system("pkill -9 -f vllm")
                time.sleep(3)  # Give time for processes to terminate
            except Exception as e:
                print(f"Warning when killing processes: {e}")
            
            # Set environment variables for vLLM (set before LLM() is constructed,
            # presumably so the engine reads them at start-up).
            os.environ["VLLM_USE_V1"] = "1"
            os.environ["VLLM_WORKER_USE_RAY"] = "0"
            os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
            
            # Attempt to initialize vLLM with conservative settings
            try:
                
                model_path = self.policy_lm if updated_model_path is None else updated_model_path   
                # Remembered so get_action can rebuild the same engine after a timeout.
                self.updated_model_path = model_path
                self.model = LLM(
                    model=model_path,
                    max_model_len=self.max_prompt_length,
                    tensor_parallel_size=self.vllm_tensor_parallel_size, 
                    gpu_memory_utilization=0.8,
                    enforce_eager=True,  # Avoid CUDA graphs
                    disable_log_stats=True,
                    limit_mm_per_prompt={"image": 4}  # at most 4 images per prompt
                )
                print(f"vLLM initialization successful from {model_path}")
            except Exception as e:
                print(f"Error initializing vLLM: {e}")
                import traceback
                traceback.print_exc()
                # Don't raise - continue with a null model
                self.model = None
        else:
            # Non-rank 0 processes just update their state
            self.model = None
        
        # print(f"Successfully transitioned to inference mode from {previous_mode}")

    def enter_train_mode(self, updated_model_path=None):
        """
        Transition to training mode with robust cleanup of inference resources.

        Steps:
          1. Only when coming from "infer" and on deepspeed rank 0: tear down the
             vLLM engine (``destroy_model_parallel``, drop the driver worker and the
             ``LLM`` object), shut down Ray, and ``pkill`` ray/vllm processes.
          2. On every rank: ``gc.collect()`` and ``torch.cuda.empty_cache()``, then
             set ``self.mode = "train"``.
          3. On every rank: if ``self.train_model`` is None, load
             ``Gemma3ForConditionalGeneration`` on CPU in bfloat16 from
             ``updated_model_path`` (or ``self.policy_lm``). An existing
             ``train_model`` is reused and ``updated_model_path`` is then ignored.

        Parameters:
            updated_model_path (str | None): weights to load when no training model
                exists yet.

        All failures are printed and swallowed; on a load failure ``train_model``
        stays None.
        """
        previous_mode = self.mode
        

        if previous_mode == "train":
            return
        
        # Step 1: Clean up vLLM/Ray resources
        # (skipped on the very first call, when previous_mode is None)
        if previous_mode == "infer" and deepspeed.comm.get_rank() == 0:
            print(f"Transitioning to training mode from {previous_mode}")
            # Clean up vLLM resources if they exist
            if hasattr(self, "model") and self.model is not None:
                try:
                    destroy_model_parallel()
                    try:
                        del self.model.llm_engine.driver_worker
                    except Exception as e:
                        # Best effort; the exception detail is discarded.
                        print(f"Warning deleting driver_worker")
                    del self.model
                    self.model = None
                    print("Cleaned up vLLM model")
                except Exception as e:
                    print(f"Warning during vLLM cleanup: {e}")
                    self.model = None
            
            # Kill any Ray processes
            # NOTE(review): these send SIGTERM (no -9, unlike enter_infer_mode), and
            # `pkill -f` matches any command line containing "ray"/"vllm".
            try:
                if ray.is_initialized():
                    ray.shutdown()
                os.system("pkill -f ray")
                os.system("pkill -f vllm")
                time.sleep(3)  # Give processes time to terminate
            except Exception as e:
                print(f"Warning during Ray cleanup: {e}")
        
        # Clear GPU memory on all ranks
        gc.collect()
        torch.cuda.empty_cache()
        
        # Step 2: Update mode to train
        self.mode = "train"
        
        # Step 3: Load the HF model on all ranks
        model_path = self.policy_lm if updated_model_path is None else updated_model_path
        rank = deepspeed.comm.get_rank()
        # print(f"Loading HF model from {model_path} for training on rank {rank}")

        if self.train_model is None:
            try:
                # Load with minimal memory usage (kept on CPU here)
                self.train_model = Gemma3ForConditionalGeneration.from_pretrained(
                    model_path,
                    torch_dtype=torch.bfloat16,
                    device_map="cpu",
                    low_cpu_mem_usage=True,
                )
                print(f"Successfully loaded HF model for training on rank {rank}")
            except Exception as e:
                print(f"Error loading model on rank {rank}: {e}")
                import traceback
                traceback.print_exc()
                self.train_model = None
        else:
            print(f"Use existing on rank {rank}")
        
        # print("Training mode transition complete")
    
    # 1. Update get_gemma_vllm_prompts in GemmaVllmAgent class
    def get_gemma_hf_prompts(self, unprocessed_observation, system_prompt=None, transfer_image_to_base64=True):
        observation = deepcopy(unprocessed_observation)
        # result = []
        
        # Process each item in observation
        for obs_idx, obs in enumerate(observation):
            # # Skip invalid observation types
            # if isinstance(obs, int):
            #     print(f"Warning: observation item {obs_idx} is an integer: {obs}")
            #     continue
            # if not isinstance(obs, list):
            #     print(f"Warning: observation item {obs_idx} is not a list: {type(obs)}")
            #     continue
            
            # Process each message in observation
            # formatted_rounds = []
            for round_idx, round_copy in enumerate(obs):
                # Skip invalid rounds
                # if not isinstance(round_item, dict) or 'content' not in round_item:
                #     print(f"Warning: round {round_idx} is invalid: {type(round_item)}")
                #     continue
                
                # # Create a copy to modify
                # round_copy = deepcopy(round_item)
                
                # Format string content to proper structure
                # if isinstance(round_copy['content'], str):
                #     round_copy['content'] = [{"type": "text", "text": round_copy['content']}]
                # elif not isinstance(round_copy['content'], list):
                #     print(f"Warning: content is not a list or string: {type(round_copy['content'])}")
                #     continue
                if round_copy['role'] == "user":
                # Process each content item
                # valid_content = []
                    for i in range(len(round_copy['content'])):
                        content_item = round_copy['content'][i]
                        
                        # # Ensure content item is a dict
                        # if not isinstance(content_item, dict):
                        #     print(f"Warning: content item {i} is not a dict: {type(content_item)}")
                        #     content_item = {"type": "text", "text": str(content_item)}
                        
                        # # Check for empty text with detailed logging
                        # if content_item.get('type') == 'text' and not content_item.get('text'):
                        #     print(f"Warning: text content item {i} has empty text")
                        #     print(f"Full content item: {content_item}")
                        #     print(f"Full round: {round_copy}")
                        #     print(f"Round role: {round_copy.get('role', 'unknown')}")
                        #     print(f"Message index: {round_idx} in observation {obs_idx}")
                            
                        #     # Use space to avoid empty text warning
                        #     content_item['text'] = " "
                        
                        # Process source if present
                        if 'source' in content_item:
                            # Handle image paths
                            if content_item['source'].get('type') == 'path':
                                img_path = content_item['source']['path']
                                observation[obs_idx][round_idx]['content'][i]['type'] = "image"
                                if transfer_image_to_base64:
                                    with open(img_path, "rb") as image_file:
                                        img_data = base64.b64encode(image_file.read()).decode('utf-8')
                                    observation[obs_idx][round_idx]['content'][i]['image'] = f"data:image/png;base64,{img_data}"
                                else:
                                    observation[obs_idx][round_idx]['content'][i]['image'] = img_path
                                del observation[obs_idx][round_idx]['content'][i]['source']
                                
                            # Handle base64 data
                            elif 'data' in content_item['source']:
                                observation[obs_idx][round_idx]['content'][i]['type'] = "image_url"
                                observation[obs_idx][round_idx]['content'][i]['image_url'] = {"url": f"data:image/png;base64,{content_item['source']['data']}"}
                                del  observation[obs_idx][round_idx]['content'][i]['source']
                    
                    # valid_content.append(content_item)
                
                # Update content with validated items
                # round_copy['content'] = valid_content
                # formatted_rounds.append(round_copy)
            
            # Add system prompt if we have valid rounds
            # if formatted_rounds:
            #     formatted_rounds.insert(0, {
            #         "role": "system",
            #         "content": [
            #             {
            #                 "type": "text",
            #                 "text": system_prompt
            #             }
            #         ]
            #     })
            #     result.append(formatted_rounds)
        
        return observation
        
    # 1. Update get_gemma_vllm_prompts in GemmaVllmAgent class
    def get_gemma_vllm_prompts(self, unprocessed_observation, system_prompt=None, transfer_image_to_base64=True):
        
        observation = deepcopy(unprocessed_observation)
        # result = []
        
        # Process each item in observation
        for obs_idx, obs in enumerate(observation):
            # # Skip invalid observation types
            # if isinstance(obs, int):
            #     print(f"Warning: observation item {obs_idx} is an integer: {obs}")
            #     continue
            # if not isinstance(obs, list):
            #     print(f"Warning: observation item {obs_idx} is not a list: {type(obs)}")
            #     continue
            
            # Process each message in observation
            # formatted_rounds = []
            for round_idx, round_copy in enumerate(obs):
                # Skip invalid rounds
                # if not isinstance(round_item, dict) or 'content' not in round_item:
                #     print(f"Warning: round {round_idx} is invalid: {type(round_item)}")
                #     continue
                
                # # Create a copy to modify
                # round_copy = deepcopy(round_item)
                
                # Format string content to proper structure
                # if isinstance(round_copy['content'], str):
                #     round_copy['content'] = [{"type": "text", "text": round_copy['content']}]
                # elif not isinstance(round_copy['content'], list):
                #     print(f"Warning: content is not a list or string: {type(round_copy['content'])}")
                #     continue
                if round_copy['role'] == "user":
                # Process each content item
                # valid_content = []
                    for i in range(len(round_copy['content'])):
                        content_item = round_copy['content'][i]
                        
                        # # Ensure content item is a dict
                        # if not isinstance(content_item, dict):
                        #     print(f"Warning: content item {i} is not a dict: {type(content_item)}")
                        #     content_item = {"type": "text", "text": str(content_item)}
                        
                        # # Check for empty text with detailed logging
                        # if content_item.get('type') == 'text' and not content_item.get('text'):
                        #     print(f"Warning: text content item {i} has empty text")
                        #     print(f"Full content item: {content_item}")
                        #     print(f"Full round: {round_copy}")
                        #     print(f"Round role: {round_copy.get('role', 'unknown')}")
                        #     print(f"Message index: {round_idx} in observation {obs_idx}")
                            
                        #     # Use space to avoid empty text warning
                        #     content_item['text'] = " "
                        
                        # Process source if present
                        if 'source' in content_item:
                            # Handle image paths
                            if content_item['source'].get('type') == 'path':
                                img_path = content_item['source']['path']
                                observation[obs_idx][round_idx]['content'][i]['type'] = "image_url"
                                if transfer_image_to_base64:
                                    with open(img_path, "rb") as image_file:
                                        img_data = base64.b64encode(image_file.read()).decode('utf-8')
                                    observation[obs_idx][round_idx]['content'][i]['image_url'] = {"url": f"data:image/png;base64,{img_data}"}
                                else:
                                    observation[obs_idx][round_idx]['content'][i]['image_url'] = {"url": img_path}
                                del observation[obs_idx][round_idx]['content'][i]['source']
                                
                            # Handle base64 data
                            elif 'data' in content_item['source']:
                                observation[obs_idx][round_idx]['content'][i]['type'] = "image_url"
                                observation[obs_idx][round_idx]['content'][i]['image_url'] = {"url": f"data:image/png;base64,{content_item['source']['data']}"}
                                del  observation[obs_idx][round_idx]['content'][i]['source']
                    
                    # valid_content.append(content_item)
                
                # Update content with validated items
                # round_copy['content'] = valid_content
                # formatted_rounds.append(round_copy)
            
            # Add system prompt if we have valid rounds
            # if formatted_rounds:
            #     formatted_rounds.insert(0, {
            #         "role": "system",
            #         "content": [
            #             {
            #                 "type": "text",
            #                 "text": system_prompt
            #             }
            #         ]
            #     })
            #     result.append(formatted_rounds)
        
        return observation

    # 2. Update get_single_gemma_hf_prompts method
    def get_single_gemma_hf_prompts(self, observation):
        """
        Build a Hugging Face chat for one conversation, prefixed with SYSTEM_PROMPT.

        Works on a deep copy. String contents are wrapped as a single text item.
        Any content item with a ``source`` becomes ``{"type": "image", "image": ...}``
        where ``image`` is the file path (``source["type"] == "path"``) or the raw
        base64 payload (``source["data"]``, no data-URL prefix).

        Returns:
            A new list: the system message followed by the converted messages.
        """
        messages = deepcopy(observation)

        for message in messages:
            if type(message['content']) is str:
                message['content'] = [{"type": "text", "text": message['content']}]
            for entry in message['content']:
                if 'source' not in entry:
                    continue
                src = entry['source']
                if src.get('type') == 'path':
                    image_ref = src['path']
                elif 'data' in src:
                    image_ref = src['data']
                else:
                    continue
                entry['type'] = "image"
                entry['image'] = image_ref
                del entry['source']

        system_message = {
            "role": "system",
            "content": [{"type": "text", "text": SYSTEM_PROMPT}],
        }
        return [system_message] + messages

    # 3. Update get_single_gemma_vllm_prompts method
    def get_single_gemma_vllm_prompts(self, observation, evaluator_prompt):
        observation = deepcopy(observation)
        for round in observation:
            if type(round['content']) == str:
                round['content'] = [{"type": "text", "text": round['content']}]
            else:
                for i in range(len(round['content'])):
                    if 'source' in round['content'][i]:
                        # Check if the source type is 'path'
                        if round['content'][i]['source'].get('type') == 'path':
                            # Get the image path
                            img_path = round['content'][i]['source']['path']
                            # Read and encode the image
                            with open(img_path, "rb") as image_file:
                                img_data = base64.b64encode(image_file.read()).decode('utf-8')
                            # Update the format for vLLM
                            round['content'][i]['type'] = "image_url"
                            round['content'][i]['image_url'] = {"url": f"data:image/png;base64,{img_data}"}
                            del round['content'][i]['source']
                        # Handle the case where source contains base64 data
                        elif 'data' in round['content'][i]['source']:
                            round['content'][i]['type'] = "image_url"
                            round['content'][i]['image_url'] = {"url": f"data:image/png;base64,{round['content'][i]['source']['data']}"}
                            del round['content'][i]['source']
        # Prepend a system prompt
        observation.insert(0, {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": evaluator_prompt
                }
            ]
        })
        return observation

    
    def convert_single_vllm_msg_to_hf(self, messages):
        messages = deepcopy(messages)
        for round in messages:
            if len(round['content']) == 2 and round['content'][1]['type'] == "image_url":
                round['content'][1]['type'] = "image"
                round['content'][1]['image'] = round['content'][1]['image_url']['url']
                del round['content'][1]['image_url']
        return messages
        
    def get_action(self, unprocessed_messages):
        """
        Get action from the model with improved error handling and logging
        """
        gen_ready = threading.Event()
        gen_result = [None]
        gen_error = [None]

        def generate_thread():
            with torch.no_grad():
                try:
                    # Process messages with better error handling
                    messages = self.get_gemma_vllm_prompts(unprocessed_messages)
                    bsz = len(messages)
                    
                    # Log message processing results
                    print(f"Processed {len(messages)} message lists")
                                            
                    # Try to get a response from the model
                    try:
                        response = self.model.chat(
                            messages=messages,
                            sampling_params=self.sampling_params,
                            chat_template=None,
                        )
                        
                    except Exception as e:
                        print("error in vLLM chat", e)
                        # might be too long output error
                        outputs = []
                        probs = []
                        trunc_messages = []
                        for msg_idx, msg in enumerate(messages):
                            total_len = 0
                            for msg_ in msg:
                                total_len += len(self.processor.tokenizer([msg_['content'][0]['text']])["input_ids"][0])
    
                            if total_len > self.max_prompt_length * 0.8:
                                
                                msg[-1]['content'][0]['text'] = msg[1]['content'][0]['text'][:re.search("Task: ", msg[1]['content'][0]['text']).start()] + msg[-1]['content'][0]['text'][re.search("Task: ", msg[-1]['content'][0]['text']).start():]
                                trunc_input_ids = self.processor.tokenizer([msg[-1]['content'][0]['text']])["input_ids"][0][:int(self.max_prompt_length * 0.8)]
                                trunc_input = self.processor.tokenizer.batch_decode([trunc_input_ids])[0]
                                msg[-1]['content'][0]['text'] = msg[-1]['content'][0]['text'][:len(trunc_input)]
                                trunc_messages.append([(msg[0])] + [(msg[-1])])
                                
                            else:
                                trunc_messages.append((msg))
    
                        response = self.model.chat(
                        messages=trunc_messages,
                        sampling_params=self.sampling_params,
                            chat_template=None,
                    )
                        
                    print(f"Received {len(response)} responses from model")
                    outputs = []
                    probs = []
                    for i, out in enumerate(response):
                        try:
                            
                            generated_text = out.outputs[0].text
                            prob = out.outputs[0].logprobs
                            p = 0
                            for tok_prob in prob:
                                for k, v in tok_prob.items():
                                    p += v.logprob
                                
                            outputs.append(generated_text)
                            probs.append(p / len(prob))
    
                            # print(f"Successfully processed and validated output {i}")
                        except Exception as e:
                            print(f"Error processing output {i}: {e}")
                            outputs.append("Thought: There was an error processing the model output.\nAction: Wait")
                            probs.append(-100)
    
                    if len(outputs) > bsz:
                        outputs = outputs[-bsz:]
                        probs = probs[-bsz:]

                    gen_result[0] = (outputs, probs)
                    gen_ready.set()
                    return 
                        
                except Exception as e:
                    print(f"CRITICAL ERROR in get_action: {e}")
                    import traceback
                    traceback.print_exc()
                    # Return a fallback response that will never be empty
                    gen_result[0] = (["Thought: There was a critical error in the model.\nAction: Wait"] * len(messages), [0] * len(messages))
                    gen_ready.set()
                    return
        thread = threading.Thread(target=generate_thread)
        thread.daemon = True
        thread.start()
        
        
        # Wait for generation to complete or timeout
        timeout_seconds = 300
        if not gen_ready.wait(timeout_seconds):
            logging.error(f"vLLM generation timed out after {timeout_seconds} seconds")

            logging.info("Attempting to reinitialize vLLM model after timeout")
            self.enter_train_mode()
            self.enter_infer_mode(self.updated_model_path)
            logging.info("Retrying generation with reinitialized model")
            try:
                generate_thread()
                return gen_result[0][0], gen_result[0][1]
            except Exception as e:
                logging.error(f"Retry generation failed: {e}")
                return ["Thought: There was a critical error in the model.\nAction: Wait"] * len(messages), [0] * len(messages)

        return gen_result[0][0], gen_result[0][1]

    def get_reward(self, messages_list, evaluator_prompt):
        """
        Get reward by sending batched evaluation requests to the online vLLM evaluator service.
        
        Parameters:
            messages_list (list): List of message exchanges to evaluate
            evaluator_prompt (str): System prompt for the evaluator
            
        Returns:
            list: List of evaluator responses, one per input message
        """
        with torch.no_grad():
            start_time = time.time()
            
            # Check if an evaluator endpoint is provided
            if not hasattr(self, 'evaluator_ip_address') or not self.evaluator_ip_address:
                print("ERROR: No evaluator_ip_address specified.")
                return ["ERROR: No evaluator endpoint configured"] * (len(messages_list) if messages_list else 1)
            
            # Filter out any invalid messages
            valid_messages = []
            for idx, msg in enumerate(messages_list):
                if isinstance(msg, int):
                    print(f"Warning: message {idx} is an integer: {msg}")
                    continue
                valid_messages.append(msg)
            
            if not valid_messages:
                print("No valid messages to process.")
                return []
            
            try:
                # Format messages for the online evaluator (add system prompt, process images, etc.)
                processed_messages = self.get_gemma_vllm_prompts(valid_messages, system_prompt=evaluator_prompt)
                
                if not processed_messages:
                    print("No valid messages after processing.")
                    return ["ERROR"] * len(valid_messages)
                
                # Print message statistics for debugging
                num_messages = len(processed_messages)
                total_images = 0
                for msg_list in processed_messages:
                    for msg in msg_list:
                        if 'content' in msg:
                            for content_item in msg['content']:
                                if isinstance(content_item, dict) and content_item.get('type') == 'image_url':
                                    total_images += 1
                
                print(f"Sending batch with {num_messages} messages and {total_images} images to evaluator")
                
                # Prepare the data to send to the evaluator service
                request_data = {
                    "messages": processed_messages,
                    "system_prompt": evaluator_prompt
                }
                
                # Send the data to the online evaluator
                import requests
                import json
                
                # Format the URL from the evaluator_ip_address
                server_url = self.evaluator_ip_address
                if not server_url.startswith('http'):
                    # Check if it includes a port number
                    if ':' in server_url:
                        server_url = f"http://{server_url}"
                    else:
                        # Default to port 7860 (Gradio's default)
                        server_url = f"http://{server_url}:7860"
                
                # Make sure the URL ends with /api/predict
                if not server_url.endswith('/api/predict'):
                    if server_url.endswith('/'):
                        server_url += 'api/predict'
                    else:
                        server_url += '/api/predict'
                        
                print(f"Connecting to evaluator at: {server_url}")
                
                # Serialize the request data
                serialized_data = json.dumps(request_data)
                
                # Implement retry logic
                max_retries = 5
                base_delay = 2  # seconds
                
                for retry in range(max_retries):
                    try:
                        # Set a timeout to avoid hanging indefinitely
                        timeout = 600  # 10 minutes should be enough for large batches
                        
                        # Send the request
                        response = requests.post(
                            url=server_url,
                            json={"data": [serialized_data]},
                            timeout=timeout
                        )
                        
                        # Check if the request was successful
                        response.raise_for_status()
                        
                        # Parse the response
                        response_json = response.json()
                        
                        if "data" not in response_json:
                            raise ValueError(f"Unexpected response format: {response_json}")
                        
                        # Parse the response data
                        response_data = response_json["data"]
                        
                        # Handle different response data types
                        if isinstance(response_data, str):
                            try:
                                # Try to parse as JSON
                                output_data = json.loads(response_data)
                            except json.JSONDecodeError:
                                # If not valid JSON, use as is
                                output_data = {"outputs": [response_data]}
                        elif isinstance(response_data, list):
                            # If it's a list, check the first item
                            if response_data and isinstance(response_data[0], str):
                                try:
                                    output_data = json.loads(response_data[0])
                                except json.JSONDecodeError:
                                    output_data = {"outputs": response_data}
                            else:
                                output_data = {"outputs": response_data}
                        elif isinstance(response_data, dict):
                            # If it's already a dict, use it directly
                            output_data = response_data
                        else:
                            output_data = {"outputs": [str(response_data)]}
                        
                        # Check for errors in the output data
                        if "error" in output_data:
                            raise ValueError(f"Error from evaluator: {output_data['error']}")
                        
                        # Extract the outputs
                        outputs = output_data.get("outputs", [])
                        
                        # Ensure we have the right number of outputs
                        if len(outputs) != len(processed_messages):
                            print(f"WARNING: Got {len(outputs)} outputs for {len(processed_messages)} messages")
                            # Pad with errors if needed
                            if len(outputs) < len(processed_messages):
                                outputs.extend(["ERROR: Missing response"] * (len(processed_messages) - len(outputs)))
                            # Truncate if we got too many
                            elif len(outputs) > len(processed_messages):
                                outputs = outputs[:len(processed_messages)]
                        
                        processing_time = time.time() - start_time
                        print(f"Evaluation completed in {processing_time:.2f} seconds")
                        
                        return outputs
                        
                    except (requests.exceptions.RequestException, json.JSONDecodeError, ValueError) as e:
                        if retry < max_retries - 1:
                            delay = base_delay * (2 ** retry)
                            print(f"Request failed (attempt {retry+1}/{max_retries}): {e}")
                            print(f"Retrying in {delay} seconds...")
                            time.sleep(delay)
                        else:
                            print(f"All retries failed: {e}")
                            return ["ERROR"] * len(valid_messages)
                
                # This should not be reached, but just in case
                return ["ERROR: All retries failed"] * len(valid_messages)
                    
            except Exception as e:
                print(f"Error in get_reward: {e}")
                import traceback
                traceback.print_exc()
                return ["ERROR"] * len(valid_messages)
    
    def label_trajectories(self, trajs, evaluator_prompt):
        messages_list = []
        for i in range(len(trajs)):
            if len(trajs[i]) > 0:
                if trajs[i][-1].get('eval_prompt') != "":
                    # Check the type of eval_prompt before adding to messages_list
                    eval_prompt = trajs[i][-1]['eval_prompt']
                    # Skip if eval_prompt is an integer or invalid type
                    if isinstance(eval_prompt, int):
                        print(f"Warning: eval_prompt is an integer: {eval_prompt}, skipping")
                        continue
                    # Only add if it's a valid message list
                    if isinstance(eval_prompt, list):
                        messages_list.append(eval_prompt)
                    else:
                        print(f"Warning: unexpected eval_prompt type: {type(eval_prompt)}")
                        continue
        
        # Skip reward calculation if no valid messages
        if not messages_list:
            print("No valid evaluation prompts found.")
            return trajs
        
        # Get responses for valid messages
        try:
            response_list = self.get_reward(messages_list, evaluator_prompt)
            
            # Update trajectories with responses
            resp_idx = 0
            for i in range(len(trajs)):
                if len(trajs[i]) > 0:
                    eval_prompt = trajs[i][-1].get('eval_prompt', "")
                    if eval_prompt != "" and isinstance(eval_prompt, list):
                        if resp_idx < len(response_list):
                            response = response_list[resp_idx]
                            resp_idx += 1
                            trajs[i][-1]['eval_info'] = response
                            auto_eval_res = 1 if ("SUCCESS" in response and "NOT SUCCESS" not in response) else 0
                            trajs[i][-1]['reward'] = auto_eval_res
        except Exception as e:
            print(f"Error in get_reward: {e}")
            import traceback
            traceback.print_exc()
            
        return trajs

    def get_log_prob(self, messages, actions):
        """
        Compute log probabilities for actions with proper gradient flow and teacher-forcing shift.
        """ 
        # messages = self.get_gemma_hf_prompts(
        #     messages,
        #     system_prompt=SYSTEM_PROMPT,
        #     transfer_image_to_base64=False
        # )
        
        messages = self.get_gemma_hf_prompts(messages, transfer_image_to_base64=False)

        # Apply chat template with proper multimodal handling.
        texts = [self.processor.apply_chat_template(message, tokenize=False) for message in messages]
        images = [self._process_vision_info(message) for message in messages]
        
        # Ensure each action ends with an EOS token.
        for i in range(len(actions)):
            actions[i] = "<start_of_turn>model\n" + actions[i] + "<end_of_turn>"
        
        # Append the actions to the texts.
        for i in range(len(texts)):
            texts[i] += actions[i]
        
        # Tokenize the whole batch (texts + images).
        try:
            batch = self.processor(
                text=texts, 
                images=images, 
                return_tensors="pt", 
                padding=True
            ).to(self.train_model.device)
        except Exception as e:
            print(f"Error during tokenization: {e}")
            print(f"message: {messages}")
        
        # Also tokenize actions on their own to later know lengths and create a mask for valid tokens.
        tokenized_actions = self.processor.tokenizer(actions, padding=True)
        tokenized_actions_mask = torch.tensor(tokenized_actions["attention_mask"]).to(self.train_model.device)
        
        # Forward pass (with gradient tracking).
        outputs = self.train_model(**batch)
        logits = outputs.logits  # Shape: [batch_size, seq_length, vocab_size]
        
        # Scale logits by temperature and compute log softmax over the vocabulary.
        logits = logits / self.temperature
        log_probs = F.log_softmax(logits, dim=-1)
        
        # Determine the maximum action length (from the tokenized actions).
        max_action_length = tokenized_actions_mask.size(1)
        
        # The appended actions are at the end of the sequence.
        # For teacher forcing, we use the logits corresponding to each token position 
        # (except the very last one) to predict the next token.
        teacher_forcing_logits = log_probs[:, -max_action_length:-1, :]  
        # teacher_forcing_logits has shape: [batch_size, max_action_length - 1, vocab_size]
        
        # Prepare teacher forcing targets by shifting the tokenized actions by one (i.e. dropping the first token).
        # This aligns predictions at time t with target token at t+1.
        action_input_ids = torch.tensor(tokenized_actions["input_ids"]).to(self.train_model.device)
        teacher_forcing_targets = action_input_ids[:, 1:]  # Shape: [batch_size, max_action_length - 1]
        
        # Create a mask of tokens to exclude
        bsz, seq_len = tokenized_actions_mask.shape
        shifted_tokenized_actions_mask = torch.zeros_like(tokenized_actions_mask, dtype=tokenized_actions_mask.dtype, device=tokenized_actions_mask.device)
        shift_amount = 4
        shifted_tokenized_actions_mask[:, shift_amount:] = tokenized_actions_mask[:, :(seq_len - shift_amount)]
        teacher_forcing_mask = shifted_tokenized_actions_mask[:, 1:]  # Shape: [batch_size, max_action_length - 1]
                
        # Gather the log probabilities corresponding to the teacher forcing targets.
        action_log_probs = teacher_forcing_logits.gather(
            dim=2, index=teacher_forcing_targets.unsqueeze(-1)
        ).squeeze(-1)
        # Now, action_log_probs has shape: [batch_size, max_action_length - 1]
        
        # Zero out the padded tokens and sum the log probabilities for each example.
        summed_log_probs = (action_log_probs * teacher_forcing_mask).sum(dim=1)
        
        # Divide by the number of valid (non-padded) tokens to get the average log probability.
        valid_token_counts = teacher_forcing_mask.sum(dim=1)
        avg_log_probs = summed_log_probs / valid_token_counts

        return avg_log_probs

        
    def _process_vision_info(self, messages: list[dict]) -> list[Image.Image]:
        """
        Load every image referenced in one conversation as an RGB PIL image.

        A content element counts as an image when it is a dict that has an
        ``"image"`` key or whose ``"type"`` is ``"image"``. Its ``"image"`` value is
        passed to ``Image.open``. A message whose ``"content"`` is not a list is
        treated as a single content element. Non-dict elements are ignored.

        Parameters:
            messages: Messages of one conversation; each is a dict with an optional
                ``"content"`` entry.

        Returns:
            The images in the order they appear, each converted to RGB.

        Raises:
            KeyError: If an element has ``"type": "image"`` but no ``"image"`` key.
            Errors raised by ``Image.open`` (e.g. a missing file) propagate.
        """
        image_inputs = []
        for msg in messages:
            content = msg.get("content", [])
            if not isinstance(content, list):
                content = [content]

            for element in content:
                if not isinstance(element, dict):
                    continue
                if "image" in element or element.get("type") == "image":
                    # Image.open is lazy and keeps the source file open; convert()
                    # returns a new in-memory image, so the context manager can close
                    # the file immediately instead of leaking the handle until GC.
                    with Image.open(element['image']) as image:
                        image_inputs.append(image.convert("RGB"))
        return image_inputs
        